{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e6953b78",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "%load_ext lab_black"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "58808130",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "import sys\n",
    "\n",
    "from sklearn.manifold import TSNE\n",
    "import plotly.graph_objects as go\n",
    "\n",
    "sys.path.append(\"../\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "33bc02b3",
   "metadata": {},
   "source": [
    "## Loading Model and Vocabulary"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9289bf5b",
   "metadata": {},
   "outputs": [],
   "source": [
    "folder = \"weights/skipgram_WikiText2\"\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "model = torch.load(f\"../{folder}/model.pt\", map_location=device)\n",
    "vocab = torch.load(f\"../{folder}/vocab.pt\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0f93a095",
   "metadata": {},
   "source": [
    "## Getting Embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fb8fec7b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# embedding from first model layer\n",
    "embeddings = list(model.parameters())[0]\n",
    "embeddings = embeddings.cpu().detach().numpy()\n",
    "\n",
    "# normalization\n",
    "norms = (embeddings ** 2).sum(axis=1) ** (1 / 2)\n",
    "norms = np.reshape(norms, (len(norms), 1))\n",
    "embeddings_norm = embeddings / norms\n",
    "embeddings_norm.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "53d61ffe",
   "metadata": {},
   "source": [
    "# Visualization with t-SNE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5296057c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# get embeddings\n",
    "embeddings_df = pd.DataFrame(embeddings)\n",
    "\n",
    "# t-SNE transform\n",
    "tsne = TSNE(n_components=2)\n",
    "embeddings_df_trans = tsne.fit_transform(embeddings_df)\n",
    "embeddings_df_trans = pd.DataFrame(embeddings_df_trans)\n",
    "\n",
    "# get token order\n",
    "embeddings_df_trans.index = vocab.get_itos()\n",
    "\n",
    "# if token is a number\n",
    "is_numeric = embeddings_df_trans.index.str.isnumeric()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f53e4ea5",
   "metadata": {},
   "outputs": [],
   "source": [
    "color = np.where(is_numeric, \"green\", \"black\")\n",
    "fig = go.Figure()\n",
    "\n",
    "fig.add_trace(\n",
    "    go.Scatter(\n",
    "        x=embeddings_df_trans[0],\n",
    "        y=embeddings_df_trans[1],\n",
    "        mode=\"text\",\n",
    "        text=embeddings_df_trans.index,\n",
    "        textposition=\"middle center\",\n",
    "        textfont=dict(color=color),\n",
    "    )\n",
    ")\n",
    "fig.write_html(\"../word2vec_visualization.html\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4c1790ce",
   "metadata": {},
   "source": [
    "# Find Similar Words"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa629850",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_top_similar(word: str, topN: int = 10):\n",
    "    word_id = vocab[word]\n",
    "    if word_id == 0:\n",
    "        print(\"Out of vocabulary word\")\n",
    "        return\n",
    "\n",
    "    word_vec = embeddings_norm[word_id]\n",
    "    word_vec = np.reshape(word_vec, (len(word_vec), 1))\n",
    "    dists = np.matmul(embeddings_norm, word_vec).flatten()\n",
    "    topN_ids = np.argsort(-dists)[1 : topN + 1]\n",
    "\n",
    "    topN_dict = {}\n",
    "    for sim_word_id in topN_ids:\n",
    "        sim_word = vocab.lookup_token(sim_word_id)\n",
    "        topN_dict[sim_word] = dists[sim_word_id]\n",
    "    return topN_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e409d821",
   "metadata": {},
   "outputs": [],
   "source": [
    "for word, sim in get_top_similar(\"gernamy\").items():\n",
    "    print(\"{}: {:.3f}\".format(word, sim))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bf37551f",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "053a227b",
   "metadata": {},
   "source": [
    "# Vector Equations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2cffb0d",
   "metadata": {},
   "outputs": [],
   "source": [
    "emb1 = embeddings[vocab[\"king\"]]\n",
    "emb2 = embeddings[vocab[\"man\"]]\n",
    "emb3 = embeddings[vocab[\"woman\"]]\n",
    "\n",
    "emb4 = emb1 - emb2 + emb3\n",
    "emb4_norm = (emb4 ** 2).sum() ** (1 / 2)\n",
    "emb4 = emb4 / emb4_norm\n",
    "\n",
    "emb4 = np.reshape(emb4, (len(emb4), 1))\n",
    "dists = np.matmul(embeddings_norm, emb4).flatten()\n",
    "\n",
    "top5 = np.argsort(-dists)[:5]\n",
    "\n",
    "for word_id in top5:\n",
    "    print(\"{}: {:.3f}\".format(vocab.lookup_token(word_id), dists[word_id]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b71d8aa3",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "environment": {
   "name": "tf2-gpu.2-3.m74",
   "type": "gcloud",
   "uri": "gcr.io/deeplearning-platform-release/tf2-gpu.2-3:m74"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
